-
Notifications
You must be signed in to change notification settings - Fork 28.9k
[SPARK-53803][ML][Feature] Added ArimaRegression for time series forecasting in MLlib #52519
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
…sting in MLlib Add ArimaRegression for time series forecasting in MLlib
@sryza Could you review this commit |
|
||
override def fit(dataset: Dataset[_]): ArimaRegressionModel = { | ||
// Dummy: assumes data is ordered with one feature column "y" | ||
val ts = dataset.select("y").rdd.map(_.getDouble(0)).collect() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If this implementation needs to collect all the dataset, then I doesn't think it match the requirement
Be highly scalable
from the MLlib-specific contribution guidelines
here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
And I think a more mature usage is using https://www.statsmodels.org/stable/generated/statsmodels.tsa.arima.model.ARIMA.html with pyspark
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point @zhengruifeng — agreed that .collect() breaks scalability.
I’ll refactor the logic to keep computations distributed (likely using mapPartitions or partition-level fitting).
Thanks for pointing out the statsmodels ARIMA example — I’ll review that for consistency.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I’ll refactor the logic to keep computations distributed
@anandexplore before working on it, let's wait for more input from @WeichenXu123 @srowen @mengxr on the need of new algorithm
What changes were proposed in this pull request?
This pull request adds a new feature called ArimaRegression to Spark MLlib under org.apache.spark.ml.regression.
It brings the ARIMA (AutoRegressive Integrated Moving Average) model for one-variable (univariate) time series forecasting, along with a matching model class ArimaRegressionModel.
The update includes:
Scala code for ArimaRegression and ArimaRegressionModel
Support for ARIMA parameters: p, d, and q
PySpark API bindings for both classes
Unit tests in Scala and Python
Model save/load support using MLWritable and MLReadable
Example usage in examples/ml/ArimaRegressionExample.scala
Why are the changes needed?
Currently, Spark MLlib does not have built-in tools for time-series forecasting.
ARIMA is one of the most common models for predicting trends in time-based data.
Adding this feature allows Spark users to perform forecasting directly within MLlib, without needing outside Python libraries. It also makes Spark’s machine learning toolkit more complete.
Does this PR introduce any user-facing change?
Yes.
New APIs are available in both Scala and Python:
org.apache.spark.ml.regression.ArimaRegression
org.apache.spark.ml.regression.ArimaRegressionModel
pyspark.ml.regression.ArimaRegression
pyspark.ml.regression.ArimaRegressionModel
These follow standard Spark ML APIs and work with Pipelines, ParamMaps, save/load, and transform().
How was this patch tested?
Tests were added in:
Scala (ArimaRegressionSuite.scala) for:
Model fitting and transforming
Parameter defaults and setters
Save/load functions
Python (test_regression.py) for PySpark interface
Manual testing was also done in both:
spark-shell (Scala)
pyspark (Python)
Manual Tested
Scala:
import org.apache.spark.ml.regression.ArimaRegression
import org.apache.spark.sql.SparkSession
val spark = SparkSession.builder().appName("ArimaRegressionExample").getOrCreate()
import spark.implicits._
val data = Seq(100.0, 102.5, 101.0, 104.0, 107.5, 110.0).toDF("value")
val arima = new ArimaRegression()
.setP(1)
.setD(1)
.setQ(1)
val model = arima.fit(data)
val forecast = model.transform(data)
forecast.show(false)
Python:
from pyspark.ml.regression import ArimaRegression
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("ArimaRegressionExample").getOrCreate()
data = [(100.0,), (102.5,), (101.0,), (104.0,), (107.5,), (110.0,)]
df = spark.createDataFrame(data, ["value"])
arima = ArimaRegression(p=1, d=1, q=1)
model = arima.fit(df)
forecast = model.transform(df)
forecast.show(truncate=False)
Predictions and output schema were checked for correctness
Was this patch authored or co-authored using generative AI tooling?
No.